{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "04Ug2C7GHUNO"
      },
      "outputs": [],
      "source": [
        "##### Copyright 2025 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "ieA4poCNHdSY"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SVCGxEikHgUc"
      },
      "source": [
        "# MatFormer Lab"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dOG0dm7LHmBE"
      },
      "source": [
        " <table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_3n]MatFormer_Lab.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4o6A5O7ofZTH"
      },
      "source": [
        "Gemma 3n is a multimodal, multilingual model from the Gemma family of models. You can read about Gemma 3n in [its docs](https://ai.google.dev/gemma/docs/gemma-3n) and [launch blog post](https://developers.googleblog.com/en/introducing-gemma-3n-developer-guide). It is a unique model that is natively elastic, which means you have nested models! Gemma 3n was trained as an E4B model (effectively loads 4B parameters) with 35 layers and a 16,384 FFN hidden dimension. It has an E2B model (30 layers and 8,192 FFN hidden dimension) nested within it, and the two are jointly trained as a MatFormer.\n",
        "\n",
        "The [MatFormer](https://arxiv.org/abs/2310.07707) (🪆Matryoshka Transformer) architecture is a novel nested transformer built for elastic inference. Think of it like Matryoshka dolls: a larger model contains smaller, fully functional versions of itself. This approach extends the concept of [Matryoshka Representation Learning](https://huggingface.co/papers/2205.13147) from just embeddings to all transformer components.\n",
        "\n",
        "Saving memory by having E2B nested within E4B is practically useful, but what makes MatFormer so powerful is its capability to smoothly span the entire Pareto-optimal accuracy-vs-model-size curve between E2B and E4B without any additional training. Using a simple technique called **Mix-n-Match**, one can extract a model of any size between E2B and E4B from the main E4B model.\n",
        "\n",
        "Why would you slice a model? Given your specific deployment requirements, the E2B and E4B may not be the right fit. You may want to get an E3B, for example, which has quality higher than the E2B while requiring less compute than E4B.\n",
        "\n",
        "**In this notebook, you get to play with MatFormers and Mix-n-Match. You will specify which configuration you want for the submodel based on FFN dimensions and skipped layers, and then you will export the model to Hugging Face, enabling you to use it with your favorite tools.**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ys5kfGnpHsDV"
      },
      "source": [
        "## Setup\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "YBXCVVERIGzM"
      },
      "outputs": [],
      "source": [
        "# @title Install dependencies\n",
        "# @markdown Run this cell to install all the required dependencies. In particular, you'll need Hugging Face `transformers` and `timm` versions that support Gemma 3n. Note that you may need to restart the notebook after executing the following cell.\n",
        "\n",
        "# Install a transformers version that supports Gemma 3n (>= 4.53)\n",
        "!pip install \"transformers>=4.53\" \"timm>=1.0.16\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "GFh8AhruHeMl"
      },
      "outputs": [],
      "source": [
        "# @title Login to Hugging Face\n",
        "# @markdown This is required so you can push the model to Hugging Face. You also need to make sure you have access to the Gemma 3n model repositories.\n",
        "\n",
        "from huggingface_hub import notebook_login\n",
        "\n",
        "notebook_login()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "yKFl4n5cLF9J"
      },
      "outputs": [],
      "source": [
        "# @title Import and Export Options\n",
        "# @markdown The MatFormer Lab allows you to load a Gemma 3n 4B checkpoint (either pre-trained or instruct-tuned) and to slice it. Below, please specify:\n",
        "\n",
        "# @markdown * The original repository ID from the checkpoint in Hugging Face\n",
        "\n",
        "# @markdown * A local path where the model will be saved\n",
        "\n",
        "# @markdown * A name of a repository to push the new checkpoint to\n",
        "\n",
        "original_model_id = \"google/gemma-3n-E4B-it\" # @param [\"google/gemma-3n-E4B-it\", \"google/gemma-3n-E4B-pt\"]\n",
        "local_output_path = \"my_modified_gemma_3n_model\" # @param {type:\"string\"}\n",
        "push_hf_repo_id = \"username/test-submodel\"  # @param {type:\"string\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P16x_lxGCcgc"
      },
      "source": [
        "### Slicing configuration\n",
        "\n",
        "As part of the Gemma 3n release, we share optimal slicing configurations as a [Hugging Face dataset repository](https://huggingface.co/datasets/google/gemma3n-slicing-configs), although you can also experiment with your own configurations below.\n",
        "\n",
        "Each configuration specifies:\n",
        "* The hidden dimensions of the FFN\n",
        "* Which layers, if any, to skip\n",
        "* An MMLU accuracy associated with the pre-trained checkpoints\n",
        "\n",
        "`layer-level` and `block-level` configurations are a result of varying the hidden dimensions of the FFN either at a `layer-level` (fine-grained) or at a `block-level` (4 local + 1 global layers).\n",
        "\n",
        "At a `layer-level`, we found that giving the global layers (rather than the local layers) higher capacity helps with accuracy for the same model size.\n",
        "\n",
        "At a `block-level`, we found that the block skipped for E2B (i.e., layers 20-24) benefits from higher capacity when it is not skipped, and that earlier blocks can work well with lower capacity than later blocks.\n",
        "\n",
        "**We invite the community to find even better configurations that lie on the Pareto-optimal curve between E2B and E4B!**"
      ]
    },
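    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because a `block-level` configuration uses one FFN width for each block of 5 layers (4 local + 1 global), it can be written compactly and expanded into a per-layer list. The following is a minimal sketch of that expansion: the helper `expand_block_dims` and the example widths are hypothetical, chosen only for illustration, and are not a tuned configuration from the dataset.\n",
        "\n",
        "```python\n",
        "LAYERS_PER_BLOCK = 5  # 4 local + 1 global layers\n",
        "\n",
        "def expand_block_dims(block_dims):\n",
        "    # Repeat each block-level FFN width for every layer in that block.\n",
        "    return [dim for dim in block_dims for _ in range(LAYERS_PER_BLOCK)]\n",
        "\n",
        "# Illustrative widths only: 7 blocks cover the 35 layers of E4B.\n",
        "example_block_dims = [2_048 * 4] * 4 + [2_048 * 8] + [2_048 * 4] * 2\n",
        "example_ffn_hidden_dims = expand_block_dims(example_block_dims)\n",
        "print(len(example_ffn_hidden_dims))\n",
        "```\n",
        "\n",
        "No layers are skipped in this sketch, so the expanded list has one entry per layer of the full model."
      ]
    },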
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UHdEebcGYats"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "summary": "{\n  \"name\": \"df\",\n  \"rows\": 16,\n  \"fields\": [\n    {\n      \"column\": \"name\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 16,\n        \"samples\": [\n          \"Main model\",\n          \"Config for official E2B Model\",\n          \"Config for E2.98B (layer-level)\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"# Layers\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 1,\n        \"min\": 30,\n        \"max\": 35,\n        \"num_unique_values\": 2,\n        \"samples\": [\n          30,\n          35\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"# Effective Params (B)\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.6242168426650042,\n        \"min\": 1.91,\n        \"max\": 3.98,\n        \"num_unique_values\": 15,\n        \"samples\": [\n          3.79,\n          2.73\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"MMLU PT accuracy\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 14,\n        \"samples\": [\n          \"54.50%\",\n          \"60.70%\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"FFN Hidden Dims\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 16,\n        \"samples\": [\n          \"[2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 
8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8]\",\n          \"[2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4]\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"Layers Skipped\",\n      \"properties\": {\n        \"dtype\": \"category\",\n        \"num_unique_values\": 1,\n        \"samples\": [\n          \"[20, 21, 22, 23, 24]\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    }\n  ]\n}",
              "type": "dataframe",
              "variable_name": "df"
            },
            "text/html": [
              "\n",
              "  <div id=\"df-97ccff80-0790-4fa4-be96-a87f688fd2e6\" class=\"colab-df-container\">\n",
              "    <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>name</th>\n",
              "      <th># Layers</th>\n",
              "      <th># Effective Params (B)</th>\n",
              "      <th>MMLU PT accuracy</th>\n",
              "      <th>FFN Hidden Dims</th>\n",
              "      <th>Layers Skipped</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>Main model</td>\n",
              "      <td>35</td>\n",
              "      <td>3.98</td>\n",
              "      <td>62.30%</td>\n",
              "      <td>[2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2...</td>\n",
              "      <td>NaN</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>Config for official E2B Model</td>\n",
              "      <td>30</td>\n",
              "      <td>1.91</td>\n",
              "      <td>50.90%</td>\n",
              "      <td>[2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2...</td>\n",
              "      <td>[20, 21, 22, 23, 24]</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>Config for E1.96B (layer-level)</td>\n",
              "      <td>30</td>\n",
              "      <td>1.96</td>\n",
              "      <td>53.40%</td>\n",
              "      <td>[2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2...</td>\n",
              "      <td>[20, 21, 22, 23, 24]</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>Config for E2.54B (layer-level)</td>\n",
              "      <td>35</td>\n",
              "      <td>2.54</td>\n",
              "      <td>55.40%</td>\n",
              "      <td>[2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2...</td>\n",
              "      <td>NaN</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>Config for E2.69B (layer-level)</td>\n",
              "      <td>35</td>\n",
              "      <td>2.69</td>\n",
              "      <td>57.70%</td>\n",
              "      <td>[2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2...</td>\n",
              "      <td>NaN</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "    <div class=\"colab-df-buttons\">\n",
              "\n",
              "  <div class=\"colab-df-container\">\n",
              "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-97ccff80-0790-4fa4-be96-a87f688fd2e6')\"\n",
              "            title=\"Convert this dataframe to an interactive table.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
              "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
              "  </svg>\n",
              "    </button>\n",
              "\n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    .colab-df-buttons div {\n",
              "      margin-bottom: 4px;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "    <script>\n",
              "      const buttonEl =\n",
              "        document.querySelector('#df-97ccff80-0790-4fa4-be96-a87f688fd2e6 button.colab-df-convert');\n",
              "      buttonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "      async function convertToInteractive(key) {\n",
              "        const element = document.querySelector('#df-97ccff80-0790-4fa4-be96-a87f688fd2e6');\n",
              "        const dataTable =\n",
              "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                    [key], {});\n",
              "        if (!dataTable) return;\n",
              "\n",
              "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "          + ' to learn more about interactive tables.';\n",
              "        element.innerHTML = '';\n",
              "        dataTable['output_type'] = 'display_data';\n",
              "        await google.colab.output.renderOutput(dataTable, element);\n",
              "        const docLink = document.createElement('div');\n",
              "        docLink.innerHTML = docLinkHtml;\n",
              "        element.appendChild(docLink);\n",
              "      }\n",
              "    </script>\n",
              "  </div>\n",
              "\n",
              "\n",
              "    <div id=\"df-fd3958cf-454b-4213-857a-c40990ceebbb\">\n",
              "      <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-fd3958cf-454b-4213-857a-c40990ceebbb')\"\n",
              "                title=\"Suggest charts\"\n",
              "                style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "      </button>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "      --bg-color: #E8F0FE;\n",
              "      --fill-color: #1967D2;\n",
              "      --hover-bg-color: #E2EBFA;\n",
              "      --hover-fill-color: #174EA6;\n",
              "      --disabled-fill-color: #AAA;\n",
              "      --disabled-bg-color: #DDD;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "      --bg-color: #3B4455;\n",
              "      --fill-color: #D2E3FC;\n",
              "      --hover-bg-color: #434B5C;\n",
              "      --hover-fill-color: #FFFFFF;\n",
              "      --disabled-bg-color: #3B4455;\n",
              "      --disabled-fill-color: #666;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart {\n",
              "    background-color: var(--bg-color);\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: var(--fill-color);\n",
              "    height: 32px;\n",
              "    padding: 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: var(--hover-bg-color);\n",
              "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: var(--button-hover-fill-color);\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart-complete:disabled,\n",
              "  .colab-df-quickchart-complete:disabled:hover {\n",
              "    background-color: var(--disabled-bg-color);\n",
              "    fill: var(--disabled-fill-color);\n",
              "    box-shadow: none;\n",
              "  }\n",
              "\n",
              "  .colab-df-spinner {\n",
              "    border: 2px solid var(--fill-color);\n",
              "    border-color: transparent;\n",
              "    border-bottom-color: var(--fill-color);\n",
              "    animation:\n",
              "      spin 1s steps(1) infinite;\n",
              "  }\n",
              "\n",
              "  @keyframes spin {\n",
              "    0% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "      border-left-color: var(--fill-color);\n",
              "    }\n",
              "    20% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    30% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    40% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    60% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    80% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "    90% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "  }\n",
              "</style>\n",
              "\n",
              "      <script>\n",
              "        async function quickchart(key) {\n",
              "          const quickchartButtonEl =\n",
              "            document.querySelector('#' + key + ' button');\n",
              "          quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
              "          quickchartButtonEl.classList.add('colab-df-spinner');\n",
              "          try {\n",
              "            const charts = await google.colab.kernel.invokeFunction(\n",
              "                'suggestCharts', [key], {});\n",
              "          } catch (error) {\n",
              "            console.error('Error during call to suggestCharts:', error);\n",
              "          }\n",
              "          quickchartButtonEl.classList.remove('colab-df-spinner');\n",
              "          quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
              "        }\n",
              "        (() => {\n",
              "          let quickchartButtonEl =\n",
              "            document.querySelector('#df-fd3958cf-454b-4213-857a-c40990ceebbb button');\n",
              "          quickchartButtonEl.style.display =\n",
              "            google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "        })();\n",
              "      </script>\n",
              "    </div>\n",
              "\n",
              "    </div>\n",
              "  </div>\n"
            ],
            "text/plain": [
              "                              name  # Layers  # Effective Params (B)  \\\n",
              "0                       Main model        35                    3.98   \n",
              "1    Config for official E2B Model        30                    1.91   \n",
              "2  Config for E1.96B (layer-level)        30                    1.96   \n",
              "3  Config for E2.54B (layer-level)        35                    2.54   \n",
              "4  Config for E2.69B (layer-level)        35                    2.69   \n",
              "\n",
              "  MMLU PT accuracy                                    FFN Hidden Dims  \\\n",
              "0           62.30%  [2_048 * 8, 2_048 * 8, 2_048 * 8, 2_048 * 8, 2...   \n",
              "1           50.90%  [2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2...   \n",
              "2           53.40%  [2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2...   \n",
              "3           55.40%  [2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2...   \n",
              "4           57.70%  [2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2...   \n",
              "\n",
              "         Layers Skipped  \n",
              "0                   NaN  \n",
              "1  [20, 21, 22, 23, 24]  \n",
              "2  [20, 21, 22, 23, 24]  \n",
              "3                   NaN  \n",
              "4                   NaN  "
            ]
          },
          "execution_count": 6,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import pandas as pd\n",
        "\n",
        "df = pd.read_csv(\"hf://datasets/google/gemma3n-slicing-configs/configs.csv\")\n",
        "df.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ldNQRKptGnbU"
      },
      "source": [
        "Depending on your deployment scenario, you may want to pick a different config. Select your preferred one below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "BfNaTzifGa9x"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Config for official E2B Model\n",
            "\n",
            "Layers Skipped:\n",
            "[20, 21, 22, 23, 24]\n",
            "\n",
            "FFN Hidden Dims:\n",
            "[2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4, 2_048 * 4]\n"
          ]
        }
      ],
      "source": [
        "#@title Config details\n",
        "\n",
        "import ast\n",
        "\n",
        "config_name = \"Config for official E2B Model\"# @param ['Config for official E2B Model', 'Config for E1.96B (layer-level)', 'Config for E2.54B (layer-level)', 'Config for E2.69B (layer-level)', 'Config for E2.98B (layer-level)', 'Config for E3.18B (layer-level)', 'Config for E3.39B (layer-level)', 'Config for E3.59B (layer-level)', 'Config for E3.79B (layer-level)', 'Config for E2.49B (block-level)', 'Config for E2.73B (block-level)', 'Config for E2.98B (block-level)', 'Config for E3.24B (block-level)', 'Config for E3.49B (block-level)', 'Config for E3.79B (block-level)']\n",
        "\n",
        "def safe_string_to_list(value):\n",
        "    \"\"\"\n",
        "    Converts a string representation of a list into a Python list.\n",
        "    - Converts NaN/missing values to an empty list [].\n",
        "    - Uses eval() to handle expressions like '2_048 * 8'.\n",
        "    - Safely handles non-string values by returning them as is.\n",
        "    \"\"\"\n",
        "    # Priority 1: If the value is already a list, return it unchanged.\n",
        "    if isinstance(value, list):\n",
        "        return value\n",
        "\n",
        "    # Priority 2: Now that we know it's not a list, check if it's a missing value.\n",
        "    if pd.isna(value):\n",
        "        return []\n",
        "\n",
        "    # Priority 3: If it's a string, try to evaluate it.\n",
        "    if isinstance(value, str):\n",
        "        try:\n",
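        "            # Evaluate with builtins disabled; the configs only need literals and arithmetic.\n",
        "            return eval(value, {\"__builtins__\": {}})\n",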
        "        except (SyntaxError, NameError):\n",
        "            return value  # Return invalid string as is\n",
        "\n",
        "    # Fallback for any other type (like an integer)\n",
        "    return value\n",
        "\n",
        "df['FFN Hidden Dims List'] = df['FFN Hidden Dims'].apply(safe_string_to_list)\n",
        "df['Layers Skipped'] = df['Layers Skipped'].apply(safe_string_to_list)\n",
        "\n",
        "df_indexed = df.set_index('name')\n",
        "model_row = df_indexed.loc[config_name]\n",
        "\n",
        "layers_to_skip = model_row['Layers Skipped']\n",
        "ffn_hidden_dims = model_row['FFN Hidden Dims List']\n",
        "ffn_hidden_dims_str = model_row['FFN Hidden Dims']\n",
        "\n",
        "print(config_name)\n",
        "print(\"\\nLayers Skipped:\")\n",
        "print(layers_to_skip)\n",
        "print(\"\\nFFN Hidden Dims:\")\n",
        "print(ffn_hidden_dims_str)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zMY2cY6JY15G"
      },
      "source": [
        "If you want to set up your own configuration, uncomment the assignments in the following cell and set your own values."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4Lip1avNY9xd"
      },
      "outputs": [],
      "source": [
        "# Custom config\n",
        "#\n",
        "# layers_to_skip = [] # e.g. [20, 21, 22, 23, 24]\n",
        "# ffn_hidden_dims = [] # e.g. [2048 * 4, ...]\n",
        "# ffn_hidden_dims_str = str(ffn_hidden_dims)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CmZcNRV5IUmk"
      },
      "source": [
        "## Slicing"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B7r0gnOaIWO3"
      },
      "source": [
        "### Load the model config and verify slicing configuration\n",
        "\n",
        "Note: we do not load the model at this stage; we only verify that the slicing configuration is valid."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "O9qsekLDLPiv"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "6d8a8117e0ca40ae89d94f9968754e59",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/4.11k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from transformers import AutoConfig, AutoTokenizer\n",
        "\n",
        "original_config = AutoConfig.from_pretrained(original_model_id)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8pgVAGEHA3w4"
      },
      "outputs": [],
      "source": [
        "model_config = original_config.text_config\n",
        "\n",
        "num_layers = model_config.num_hidden_layers\n",
        "final_num_layers = num_layers - len(layers_to_skip)\n",
        "\n",
        "if len(ffn_hidden_dims) != final_num_layers:\n",
        "    raise ValueError(\n",
        "        f\"The length of ffn_hidden_dims ({len(ffn_hidden_dims)}) must be equal \"\n",
        "        f\"to the final number of layers ({final_num_layers}).\"\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pytVtFyqJPfp"
      },
      "source": [
        "### Update configuration"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kz8KQOEEBERT"
      },
      "outputs": [],
      "source": [
        "num_kv_comp_layers = model_config.num_hidden_layers - model_config.num_kv_shared_layers\n",
        "local_kv_sharing_layer_idx = num_kv_comp_layers - 2\n",
        "global_kv_sharing_layer_idx = num_kv_comp_layers - 1\n",
        "\n",
        "if (local_kv_sharing_layer_idx in layers_to_skip or global_kv_sharing_layer_idx in layers_to_skip):\n",
        "  raise ValueError(f'Layers {local_kv_sharing_layer_idx} and {global_kv_sharing_layer_idx} are reserved.')\n",
        "\n",
        "count_kv_sharing = sum(1 for layer in layers_to_skip if layer >= 20)\n",
        "model_config.num_kv_shared_layers -= count_kv_sharing"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lXEK9hdyBKA1"
      },
      "outputs": [],
      "source": [
        "count_activation_sparsity = sum(1 for layer in layers_to_skip if layer <= 9)\n",
        "activation_sparsity_list = [0.95] * (10 - count_activation_sparsity) + [0] * (\n",
        "    final_num_layers - 10 + count_activation_sparsity\n",
        ")\n",
        "model_config.activation_sparsity_pattern = activation_sparsity_list"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "urMZZ2SLBOUr"
      },
      "outputs": [],
      "source": [
        "model_config.num_hidden_layers = final_num_layers\n",
        "model_config.intermediate_size = ffn_hidden_dims"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pjRMQ97hJFtX"
      },
      "source": [
        "### Save the configuration and the unchanged tokenizer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BGCI9MqoBP0w"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "0e8c4cb7adbe4197b8f781a2110d542d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/1.20M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "f0b38a4ccf3c45948a8979a29dd5eb29",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.model:   0%|          | 0.00/4.70M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "842f91a11aad47cda2c798a661d19686",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "66b118a3a0a743e5b4f8e0db42dde34d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/769 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "e4edeec981504b5ba326d903e7f31577",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "chat_template.jinja:   0%|          | 0.00/1.63k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "New config saved to my_modified_gemma_3n_model\n",
            "Final number of layers: 30\n"
          ]
        }
      ],
      "source": [
        "original_config.save_pretrained(local_output_path)\n",
        "tokenizer = AutoTokenizer.from_pretrained(original_model_id)\n",
        "tokenizer.save_pretrained(local_output_path)\n",
        "\n",
        "print(f\"New config saved to {local_output_path}\")\n",
        "print(f\"Final number of layers: {model_config.num_hidden_layers}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1JlVceVlJVG9"
      },
      "source": [
        "### Load the model checkpoints\n",
        "\n",
        "Note: we are saving the model to disk, so there's no need to have a large CPU/GPU."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5qh_uISmBnVR"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "698fd4b0392546c0aecb414d07a4bd61",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "8c6cd9dc712b48e4a5326828f939d99c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00001-of-00004.safetensors:   0%|          | 0.00/3.08G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "72e223c3fd3942538387393977249db6",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00002-of-00004.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "f179f9697af948a887556fb84b511b2d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00004-of-00004.safetensors:   0%|          | 0.00/2.66G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "1486ba683c8043c59a3fea1e92c84df2",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00003-of-00004.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "import os\n",
        "\n",
        "from huggingface_hub import snapshot_download\n",
        "\n",
        "model_path = snapshot_download(original_model_id, allow_patterns=[\"*.safetensors\"])\n",
        "safetensor_files = [os.path.join(model_path, f) for f in os.listdir(model_path) if f.endswith('.safetensors')]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3RHu0up8Jp6n"
      },
      "source": [
        "### Slice the model!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8I2DvFSKB285"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "9ace3835ba99415e9e359472b5f8ef71",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Processing shards:   0%|          | 0/4 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Saving shard model-00001-of-XXXXX.safetensors (size: 4.02 GB)\n",
            "Saving shard model-00002-of-XXXXX.safetensors (size: 4.29 GB)\n"
          ]
        }
      ],
      "source": [
        "from safetensors import safe_open\n",
        "from tqdm.auto import tqdm\n",
        "import re\n",
        "import torch\n",
        "import gc\n",
        "\n",
        "from safetensors.torch import save_file\n",
        "\n",
        "kept_layers_indices = [i for i in range(num_layers) if i not in layers_to_skip]\n",
        "layer_rename_map = {old_idx: new_idx for new_idx, old_idx in enumerate(kept_layers_indices)}\n",
        "\n",
        "# This will store the mapping of tensor names to the file they are saved in\n",
        "weight_map = {}\n",
        "\n",
        "# This will store tensors for the current shard we are building\n",
        "new_shard_state_dict = {}\n",
        "shard_counter = 1\n",
        "total_size = 0\n",
        "\n",
        "pbar = tqdm(total=len(safetensor_files), desc=\"Processing shards\")\n",
        "\n",
        "for shard_path in safetensor_files:\n",
        "    # Open a shard for streaming\n",
        "    with safe_open(shard_path, framework=\"pt\", device=\"cpu\") as f:\n",
        "        # Iterate over each tensor in the shard\n",
        "        for tensor_name in f.keys():\n",
        "            new_tensor_name = tensor_name\n",
        "            tensor = f.get_tensor(tensor_name)\n",
        "\n",
        "            # Case 1: Handle layer-specific parameters\n",
        "            match = re.search(r'\\.layers\\.(\\d+)\\.', tensor_name)\n",
        "            if match:\n",
        "                old_layer_idx = int(match.group(1))\n",
        "\n",
        "                # If this layer is meant to be skipped, we just continue to the next tensor\n",
        "                if old_layer_idx in layers_to_skip:\n",
        "                    continue\n",
        "\n",
        "                # Get the new sequential layer index\n",
        "                new_layer_idx = layer_rename_map[old_layer_idx]\n",
        "                new_tensor_name = tensor_name.replace(\n",
        "                    f'.layers.{old_layer_idx}.',\n",
        "                    f'.layers.{new_layer_idx}.'\n",
        "                )\n",
        "\n",
        "                # Get the target FFN dimension for this new layer\n",
        "                target_ffn_dim = ffn_hidden_dims[new_layer_idx]\n",
        "\n",
        "                # Check if this parameter is part of the FFN and needs slicing\n",
        "                if 'mlp.gate_proj.weight' in new_tensor_name or 'mlp.up_proj.weight' in new_tensor_name:\n",
        "                    # These layers project from model_dim -> ffn_hidden_dim.\n",
        "                    # We slice the output dimension (dim 0).\n",
        "                    tensor = tensor[:target_ffn_dim, :].contiguous()\n",
        "                elif 'mlp.down_proj.weight' in new_tensor_name:\n",
        "                    # This layer projects from ffn_hidden_dim -> model_dim.\n",
        "                    # We slice the input dimension (dim 1).\n",
        "                    tensor = tensor[:, :target_ffn_dim].contiguous()\n",
        "\n",
        "            # Case 2: Handle special non-layer parameters that need slicing\n",
        "            elif 'per_layer_model_projection' in tensor_name:\n",
        "                # Reshape, slice based on kept layers, and reshape back\n",
        "                reshaped_params = tensor.reshape((num_layers, tensor.shape[0] // num_layers, tensor.shape[1]))\n",
        "                tensor = reshaped_params[kept_layers_indices, :, :]\n",
        "                tensor = tensor.reshape(-1, tensor.shape[-1]).contiguous()\n",
        "\n",
        "            elif 'embed_tokens_per_layer' in tensor_name:\n",
        "                # Reshape, slice based on kept layers, and reshape back\n",
        "                reshaped_params = tensor.reshape((tensor.shape[0], num_layers, tensor.shape[1] // num_layers))\n",
        "                tensor = reshaped_params[:, kept_layers_indices, :]\n",
        "                tensor = tensor.reshape(tensor.shape[0], -1).contiguous()\n",
        "\n",
        "            # Add the (potentially modified) tensor to the new shard\n",
        "            new_shard_state_dict[new_tensor_name] = tensor\n",
        "\n",
        "            # Check if the current shard is getting too big\n",
        "            current_shard_size = sum(t.numel() * t.element_size() for t in new_shard_state_dict.values())\n",
        "            if current_shard_size > 4000000000: # Create new shard if current is over 4GB\n",
        "                shard_filename = f\"model-{(shard_counter):05d}-of-XXXXX.safetensors\"\n",
        "                print(f\"Saving shard {shard_filename} (size: {current_shard_size / 1e9:.2f} GB)\")\n",
        "                save_file(new_shard_state_dict, os.path.join(local_output_path, shard_filename), metadata={'format': 'pt'})\n",
        "\n",
        "                # Record which tensors are in this shard\n",
        "                for k in new_shard_state_dict.keys():\n",
        "                    weight_map[k] = os.path.basename(shard_filename)\n",
        "\n",
        "                # Reset for the next shard\n",
        "                shard_counter += 1\n",
        "                new_shard_state_dict = {}\n",
        "                gc.collect() # Free up memory\n",
        "    pbar.update(1)\n",
        "\n",
        "pbar.close()"
      ]
    },
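    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The FFN slicing in the loop above keeps the first `target_ffn_dim` hidden units of each MLP. The sketch below illustrates this on toy tensors, under the assumption of a generic gated MLP with a GELU activation and made-up sizes (it is not the Gemma 3n implementation). It shows that the sliced weights give the same output as the full MLP with the dropped hidden units zeroed out.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "torch.manual_seed(0)\n",
        "model_dim, full_ffn_dim, target_ffn_dim = 8, 32, 16  # toy sizes, not Gemma's\n",
        "\n",
        "# Weights use the (out_features, in_features) layout of torch.nn.Linear.\n",
        "gate_proj = torch.randn(full_ffn_dim, model_dim)\n",
        "up_proj = torch.randn(full_ffn_dim, model_dim)\n",
        "down_proj = torch.randn(model_dim, full_ffn_dim)\n",
        "\n",
        "# Keep the first target_ffn_dim hidden units: rows of gate/up, columns of down.\n",
        "gate_small = gate_proj[:target_ffn_dim, :].contiguous()\n",
        "up_small = up_proj[:target_ffn_dim, :].contiguous()\n",
        "down_small = down_proj[:, :target_ffn_dim].contiguous()\n",
        "\n",
        "def gated_mlp(x, gate, up, down):\n",
        "    # Generic gated MLP; the GELU here is a stand-in for the model's activation.\n",
        "    hidden = torch.nn.functional.gelu(x @ gate.T) * (x @ up.T)\n",
        "    return hidden @ down.T\n",
        "\n",
        "x = torch.randn(2, model_dim)\n",
        "y_small = gated_mlp(x, gate_small, up_small, down_small)\n",
        "\n",
        "# Equivalent view: run the full MLP but zero out the dropped hidden units.\n",
        "hidden_full = torch.nn.functional.gelu(x @ gate_proj.T) * (x @ up_proj.T)\n",
        "hidden_full[:, target_ffn_dim:] = 0\n",
        "y_masked = hidden_full @ down_proj.T\n",
        "\n",
        "print(y_small.shape, torch.allclose(y_small, y_masked, atol=1e-4))\n",
        "```\n",
        "\n",
        "Because only a prefix of the hidden units is kept, the three projections have to be cut with the same `target_ffn_dim`, which is why the loop looks up a single value per layer."
      ]
    },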
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FgN1yAQ-CFkE"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Saving final shard model-00003-of-XXXXX.safetensors\n"
          ]
        }
      ],
      "source": [
        "# Save any remaining tensors in the last shard\n",
        "if new_shard_state_dict:\n",
        "    shard_filename = f\"model-{(shard_counter):05d}-of-XXXXX.safetensors\"\n",
        "    print(f\"Saving final shard {shard_filename}\")\n",
        "    save_file(new_shard_state_dict, os.path.join(local_output_path, shard_filename), metadata={'format': 'pt'})\n",
        "    for k in new_shard_state_dict.keys():\n",
        "        weight_map[k] = os.path.basename(shard_filename)"
      ]
    },
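    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Besides the per-layer weights, the slicing loop also trims two tensors that pack data for all layers into a single matrix (`per_layer_model_projection` and `embed_tokens_per_layer`). A small sketch of that reshape, index, and reshape-back step on toy tensors follows. The sizes and the kept layer indices are hypothetical, and the packing layout is the one the reshapes in the loop assume.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "num_layers, per_layer_dim, model_dim, vocab = 4, 3, 5, 6  # toy sizes\n",
        "kept_layers_indices = [0, 1, 3]  # hypothetical: layer 2 is skipped\n",
        "\n",
        "# Projection packed as (num_layers * per_layer_dim, model_dim).\n",
        "projection = torch.arange(num_layers * per_layer_dim * model_dim, dtype=torch.float32)\n",
        "projection = projection.reshape(num_layers * per_layer_dim, model_dim)\n",
        "proj_kept = projection.reshape(num_layers, per_layer_dim, model_dim)[kept_layers_indices]\n",
        "proj_kept = proj_kept.reshape(-1, model_dim).contiguous()\n",
        "\n",
        "# Embedding packed as (vocab, num_layers * per_layer_dim).\n",
        "embedding = torch.arange(vocab * num_layers * per_layer_dim, dtype=torch.float32)\n",
        "embedding = embedding.reshape(vocab, num_layers * per_layer_dim)\n",
        "emb_kept = embedding.reshape(vocab, num_layers, per_layer_dim)[:, kept_layers_indices, :]\n",
        "emb_kept = emb_kept.reshape(vocab, -1).contiguous()\n",
        "\n",
        "print(proj_kept.shape, emb_kept.shape)\n",
        "```\n",
        "\n",
        "In both cases the block that belonged to the skipped layer disappears, and the blocks of the later layers move up so that they line up with the renumbered layer indices."
      ]
    },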
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1ZuSMTulGUPv"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "300"
            ]
          },
          "execution_count": 17,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "del new_shard_state_dict\n",
        "gc.collect()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wcEWpUKOGO4A"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "--- 3. Finalizing Model Save ---\n",
            "\n",
            "✅ Model slicing complete. New model saved in: my_modified_gemma_3n_model\n"
          ]
        }
      ],
      "source": [
        "import json\n",
        "print(\"\\n--- 3. Finalizing Model Save ---\")\n",
        "\n",
        "# The total number of shards we created\n",
        "num_shards = shard_counter\n",
        "\n",
        "# Update the \"XXXXX\" in the filenames to the correct total number of shards\n",
        "for i in range(1, num_shards + 1):\n",
        "    old_filename = f\"model-{(i):05d}-of-XXXXX.safetensors\"\n",
        "    new_filename = f\"model-{(i):05d}-of-{(num_shards):05d}.safetensors\"\n",
        "\n",
        "    # Rename the file\n",
        "    os.rename(os.path.join(local_output_path, old_filename), os.path.join(local_output_path, new_filename))\n",
        "\n",
        "    # Update the weight_map to point to the new filename\n",
        "    for k, v in weight_map.items():\n",
        "        if v == old_filename:\n",
        "            weight_map[k] = new_filename\n",
        "\n",
        "# Create and save the index.json file\n",
        "index_json = {\n",
        "    \"metadata\": {\n",
        "        \"total_size\": sum(os.path.getsize(os.path.join(local_output_path, f)) for f in os.listdir(local_output_path) if f.endswith('.safetensors'))\n",
        "    },\n",
        "    \"weight_map\": weight_map\n",
        "}\n",
        "\n",
        "with open(os.path.join(local_output_path, \"model.safetensors.index.json\"), \"w\") as f:\n",
        "    json.dump(index_json, f, indent=2)\n",
        "\n",
        "print(f\"\\n✅ Model slicing complete. New model saved in: {local_output_path}\")"
      ]
    },
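    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before uploading, it can be useful to confirm that the index written above agrees with the shard files on disk. The sketch below is one possible check, not part of the original workflow: `check_index` is a hypothetical helper, and it is exercised here on a toy checkpoint with made-up tensor names in a temporary directory.\n",
        "\n",
        "```python\n",
        "import json\n",
        "import os\n",
        "import tempfile\n",
        "\n",
        "import torch\n",
        "from safetensors import safe_open\n",
        "from safetensors.torch import save_file\n",
        "\n",
        "def check_index(model_dir):\n",
        "    # Return a list of problems found; an empty list means the index is consistent.\n",
        "    with open(os.path.join(model_dir, 'model.safetensors.index.json')) as f:\n",
        "        weight_map = json.load(f)['weight_map']\n",
        "    problems = []\n",
        "    for shard_name in sorted(set(weight_map.values())):\n",
        "        shard_path = os.path.join(model_dir, shard_name)\n",
        "        if not os.path.exists(shard_path):\n",
        "            problems.append(f'missing shard: {shard_name}')\n",
        "            continue\n",
        "        with safe_open(shard_path, framework='pt', device='cpu') as shard:\n",
        "            stored = set(shard.keys())\n",
        "        expected = {k for k, v in weight_map.items() if v == shard_name}\n",
        "        if stored != expected:\n",
        "            problems.append(f'tensor names differ in {shard_name}')\n",
        "    return problems\n",
        "\n",
        "# Toy checkpoint in a temporary directory (hypothetical tensor names).\n",
        "with tempfile.TemporaryDirectory() as toy_dir:\n",
        "    toy_shard = 'model-00001-of-00001.safetensors'\n",
        "    save_file({'a.weight': torch.zeros(2, 2)}, os.path.join(toy_dir, toy_shard))\n",
        "    toy_index = {'metadata': {'total_size': 16}, 'weight_map': {'a.weight': toy_shard}}\n",
        "    with open(os.path.join(toy_dir, 'model.safetensors.index.json'), 'w') as f:\n",
        "        json.dump(toy_index, f)\n",
        "    toy_problems = check_index(toy_dir)\n",
        "\n",
        "print(toy_problems)\n",
        "```\n",
        "\n",
        "To apply the same check to the sliced model, the helper could be called with `local_output_path` instead of the temporary directory."
      ]
    },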
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x1khD_77N8YX"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "560b0e55cf384c0488875e972bca52ca",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "README.md:   0%|          | 0.00/23.4k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Prepended custom description to the model card content.\n",
            "New README.md saved to 'my_modified_gemma_3n_model/README.md'\n"
          ]
        }
      ],
      "source": [
        "from huggingface_hub import  ModelCard, ModelCardData\n",
        "\n",
        "card = ModelCard.load(original_model_id)\n",
        "card.data.base_model = original_model_id\n",
        "del card.data.extra_gated_heading\n",
        "del card.data.extra_gated_prompt\n",
        "card.data.tags.append(\"matformer\")\n",
        "\n",
        "new_description = f\"\"\"\n",
        "> [!Note]\n",
        "> This is a submodel derived from `{original_model_id}`. It has been modified by slicing specific layers and resizing FFN dimensions. It is not the original model.\n",
        "> To learn more about MatFormers, please review the [launch blog](https://developers.googleblog.com/en/introducing-gemma-3n-developer-guide) and generate your own submodels\n",
        "with the [MatFormer Lab](https://goo.gle/gemma3n-matformer-lab).\n",
        ">\n",
        "\n",
        "Skipped layers: {layers_to_skip}\n",
        "\n",
        "FFN hidden dimensions: {ffn_hidden_dims_str}\n",
        "\"\"\"\n",
        "\n",
        "card.text = new_description + \"\\n\" + card.text\n",
        "print(\"Prepended custom description to the model card content.\")\n",
        "\n",
        "new_readme_path = os.path.join(local_output_path, \"README.md\")\n",
        "card.save(new_readme_path)\n",
        "print(f\"New README.md saved to '{new_readme_path}'\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BuF3U6nhOFwu"
      },
      "source": [
        "## Push the model to Hugging Face"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9hamNb9aGkjP"
      },
      "outputs": [],
      "source": [
        "from huggingface_hub import HfApi\n",
        "\n",
        "print(f\"Creating private repository: {push_hf_repo_id}\")\n",
        "\n",
        "# Instantiate the HfApi client\n",
        "api = HfApi()\n",
        "\n",
        "# Create a new private repository on the Hub.\n",
        "repo_url = api.create_repo(\n",
        "    repo_id=push_hf_repo_id,\n",
        "    private=True,\n",
        "    exist_ok=True\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0OEIMgDQHCUM"
      },
      "outputs": [],
      "source": [
        "print(f\"Uploading files from '{local_output_path}' to '{push_hf_repo_id}'...\")\n",
        "api.upload_folder(\n",
        "    folder_path=local_output_path,\n",
        "    repo_id=push_hf_repo_id,\n",
        "    repo_type=\"model\",\n",
        "    commit_message=\"Upload sliced model checkpoint\"\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "4Gf-3Tg2TaJs"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "db433ba38c874916a76d23eb45ff0ac2",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:accelerate.big_modeling:Some parameters are on the meta device because they were offloaded to the cpu and disk.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Total Parameters: 5,976,833,408\n",
            "Total Text Parameters: 4,456,156,768\n",
            "Effective Parameters (excluding vision, audio, and Per-Layer-Embeddings): 1,905,495,648\n"
          ]
        }
      ],
      "source": [
        "#@title Verify new model can be loaded\n",
        "\n",
        "\n",
        "from transformers import AutoModelForCausalLM\n",
        "import torch\n",
        "\n",
        "model = AutoModelForCausalLM.from_pretrained(push_hf_repo_id, torch_dtype=torch.bfloat16, device_map=\"auto\")\n",
        "\n",
        "print(f\"Total Parameters: {model.num_parameters():,}\") # 5,976,833,408\n",
        "print(f\"Total Text Parameters: {model.language_model.num_parameters():,}\") # 4,456,156,768\n",
        "print(f\"Effective Parameters (excluding vision, audio, and Per-Layer-Embeddings): {model.language_model.num_parameters(exclude_embeddings=True):,}\") # 1,905,495,648"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kO76mJDtOAr7"
      },
      "source": [
        "## Citation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bGDUTwIkKrnp"
      },
      "source": [
        "```\n",
        "@article{gemma_3n_2025,\n",
        "    title={Gemma 3n MatFormer Lab},\n",
        "    url={https://github.com/google-gemini/gemma-cookbook/blob/main/Gemma/%5BGemma_3n%5DMatFormer_Lab.ipynb},\n",
        "    publisher={Google DeepMind},\n",
        "    author={Puranjay Datta, Aditya Kusupati, Ryan Mullins, Omar Sanseviero, Rakesh Shivanna, Gemma Team},\n",
        "    year={2025}\n",
        "}\n",
        "```"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "B7r0gnOaIWO3",
        "pytVtFyqJPfp",
        "pjRMQ97hJFtX",
        "1JlVceVlJVG9",
        "3RHu0up8Jp6n"
      ],
      "name": "[Gemma_3n]MatFormer_Lab.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
